import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario


class Scenario(BaseScenario):
    """Predator-prey scenario for the multiagent-particle environment.

    Seven slower-but-larger adversaries ("predators") chase four faster good
    agents ("prey") around two fixed circular landmarks.  Adversaries are
    rewarded for tagging (colliding with) live good agents; good agents are
    penalized for being tagged and for leaving the arena.  An episode is
    "successful" (for the adversaries) once exactly two good agents have been
    tagged, or ends when ``world.max_steps_episode`` is reached.
    """

    def __init__(self, num_agents=3, dist_threshold=0.1, arena_size=1, identity_size=0):
        # NOTE(review): num_agents is stored but make_world() derives its own
        # agent count from num_good_agents/num_adversaries -- confirm intent.
        self.num_agents = num_agents
        self.dist_threshold = dist_threshold
        self.arena_size = arena_size
        self.identity_size = identity_size
        # Per-episode bookkeeping; all of these are reset in reset_world().
        self.flag = [False, False]
        self.count_1 = 0
        self.count_2 = 0
        self.dec_count_flag = [False, False]
        self.caught = 0     # number of good agents tagged this episode
        self.rew_other = 0  # shared bonus-reward accumulator (currently unused)

    def make_world(self):
        """Create and return the world with agents, landmarks and initial state."""
        world = World()
        world.dim_c = 2  # communication channel dimension (agents are silent)
        self.num_good_agents = 4
        self.num_adversaries = 7
        num_agents = self.num_adversaries + self.num_good_agents
        num_landmarks = 2
        # Add agents: the first num_adversaries are the pursuers.
        world.agents = [Agent() for _ in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.adversary = i < self.num_adversaries
            # Adversaries are larger, accelerate harder and move faster.
            agent.size = 0.075 if agent.adversary else 0.05
            agent.accel = 3.5 if agent.adversary else 3
            agent.max_speed = 1.0 if agent.adversary else 0.8
            agent.indx = i
            agent.caught = False
            agent.live = 1      # 1 until tagged (flipped to 0 in agent_reward)
            agent.live_adv = 1  # 1 until tagged (flipped to 0 in adversary_reward)
        # Add static, collidable landmarks (obstacles).
        world.landmarks = [Landmark() for _ in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # Make initial conditions.
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Randomize agent/landmark positions and clear episode bookkeeping."""
        # Colors: green for good agents, red for adversaries.
        for agent in world.agents:
            agent.color = np.array([0.35, 0.85, 0.35]) if not agent.adversary else np.array([0.85, 0.35, 0.35])
        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])
        # Random initial states; all agents start alive, collidable and movable.
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-self.arena_size, +self.arena_size, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
            agent.caught = False
            agent.live = 1
            agent.live_adv = 1
            agent.collide = True
            agent.movable = True
        for landmark in world.landmarks:
            if not landmark.boundary:
                # Keep landmarks slightly inside the arena boundary.
                landmark.state.p_pos = np.random.uniform(-0.8 * self.arena_size, +0.8 * self.arena_size, world.dim_p)
                landmark.state.p_vel = np.zeros(world.dim_p)
        self.flag = [False, False]
        self.count_1 = 0
        self.count_2 = 0
        self.dec_count_flag = [False, False]
        world.steps = 0
        self.caught = 0
        self.rew_other = 0
        self.is_success = False

    def benchmark_data(self, agent, world):
        """Return the number of good-agent collisions for an adversary (0 for prey)."""
        if agent.adversary:
            collisions = 0
            for a in self.good_agents(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        return 0

    def is_collision(self, agent1, agent2):
        """True when the two agents' bodies overlap (distance < sum of radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return dist < dist_min

    def is_dectective(self, agent1, agent2):
        """True when agent2 is within detection range (1 unit) of agent1.

        NOTE(review): name misspelled ("dectective"); kept for interface
        compatibility with external callers.
        """
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_dec = 1  # detection radius
        return dist < dist_dec

    def good_agents(self, world):
        """Return all agents that are not adversaries."""
        return [agent for agent in world.agents if not agent.adversary]

    def adversaries(self, world):
        """Return all adversarial agents."""
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the appropriate per-team reward function."""
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Reward for a good agent: penalized for being tagged and for exiting the arena."""
        rew = 0
        shape = False  # optional distance-based reward shaping (disabled)
        adversaries = self.adversaries(world)
        if agent.live == 1:
            if shape:
                # Shaped reward: small bonus for keeping distance from adversaries.
                for adv in adversaries:
                    rew += 0.01 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
            if agent.collide:
                # Penalize each tag by an adversary; the agent is then marked dead
                # so it is only penalized once.
                for a in adversaries:
                    if self.is_collision(a, agent):
                        rew -= 10
                        agent.live = 0

        # Penalize exiting the arena so prey cannot escape off-screen:
        # free inside |x| < 0.9, linear ramp up to |x| = 1.0, then exponential
        # (capped at 10) beyond the boundary.
        def bound(x):
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10
            return min(np.exp(2 * x - 2), 10)

        for p in range(world.dim_p):
            rew -= bound(abs(agent.state.p_pos[p]))
        return rew

    def adversary_reward(self, agent, world):
        """Reward for an adversary: +1 per live prey tagged (shared across the
        team, since every adversary iterates over all pairs), +2 bonus once
        exactly two prey have been tagged (which also flags episode success)."""
        rew = 0
        shape = False  # optional distance-based reward shaping (disabled)
        tagged_agents = 0
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:
            # Shaped reward: penalize each adversary's distance to the nearest live prey.
            for adv in adversaries:
                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos)))
                         for a in agents if a.live == 1]
                if dists:
                    rew -= 0.1 * min(dists)
        if agent.collide:
            for ag in agents:
                for adv in adversaries:
                    if ag.live == 1 and self.is_collision(ag, adv):
                        rew += 1
                        ag.live_adv = 0
                        # BUG FIX: was `ag.caught == True`, a no-op comparison;
                        # the flag was never actually set.
                        ag.caught = True
        # Count how many good agents have been tagged this step.
        for agg in agents:
            if agg.live_adv == 0:
                tagged_agents += 1
        if tagged_agents == 2:
            # Exactly two prey tagged: bonus reward and episode success.
            rew += 2
            self.is_success = True
        else:
            self.is_success = False
        return rew

    def observation(self, agent, world):
        """Observation vector for `agent` in its own reference frame.

        Layout: own velocity, own position, relative landmark positions, then
        relative positions and velocities of observed agents (all agents for an
        adversary; only adversaries for a good agent), zero-padded so both team
        observation lengths are fixed, followed by [0, agent.live].
        """
        # Relative positions of all (non-boundary) landmarks.
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        comm = []
        other_pos = []
        other_vel = []
        self.dec_count = 0  # legacy detection counter, kept for compatibility
        adversaries1 = self.adversaries(world)
        pad = np.array([0, 0])  # zero-pair used to pad observations to fixed length
        if agent.adversary:
            # Adversaries observe every other agent (teammates and prey).
            for other in world.agents:
                if other is agent:
                    continue
                comm.append(other.state.c)
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                other_vel.append(other.state.p_vel)
            # Pad so the observation length matches the 4-good-agent layout.
            if self.num_good_agents == 2:
                default_obs = np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos
                                             + other_pos + [pad] + [pad]
                                             + other_vel + [pad] + [pad]
                                             + [np.array([0, agent.live])])
            if self.num_good_agents == 3:
                default_obs = np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos
                                             + other_pos + [pad]
                                             + other_vel + [pad]
                                             + [np.array([0, agent.live])])
            if self.num_good_agents == 4:
                default_obs = np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos
                                             + other_pos + other_vel
                                             + [np.array([0, agent.live])])
        if not agent.adversary:
            # Good agents observe only the adversaries, padded with three zero
            # pairs each for positions and velocities.
            for other in adversaries1:
                if other is agent:
                    continue
                comm.append(other.state.c)
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                other_vel.append(other.state.p_vel)
            default_obs = np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos
                                         + other_pos + [pad] + [pad] + [pad]
                                         + other_vel + [pad] + [pad] + [pad]
                                         + [np.array([0, agent.live])])
        return default_obs

    def done(self, agent, world):
        """Episode terminates on step-limit or once adversary success is flagged."""
        condition1 = world.steps >= world.max_steps_episode
        return condition1 or self.is_success

    def info(self, agent, world):
        """Per-step info dict reported to the training loop."""
        return {'is_success': self.is_success, 'world_steps': world.steps}